import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay



class ConvBNReLU(nn.Layer):
    """Conv2D followed by BatchNorm2D and an optional ReLU.

    Args:
        c_in: number of input channels.
        c_out: number of output channels (also the BatchNorm2D feature count).
        kernel_size, stride, padding, dilation: forwarded positionally to
            ``nn.Conv2D``.
        norm_decay: L2Decay coefficient applied to the BatchNorm scale and
            shift parameters.
        use_relu: when True, a ReLU is applied after normalization.
        freeze_norm: when True, the BatchNorm scale/shift parameters get a
            learning rate of 0.
    """

    def __init__(self, c_in, c_out, kernel_size, stride, padding, dilation,
                 norm_decay=0., use_relu=True, freeze_norm=False):
        super(ConvBNReLU, self).__init__()
        self.use_relu = use_relu

        # No conv bias: the BatchNorm2D below carries its own shift parameter.
        self.conv = nn.Conv2D(c_in, c_out, kernel_size, stride, padding,
                              dilation, bias_attr=False)

        # "Freezing" the norm layer is done by zeroing its params' learning rate.
        lr = 1.
        if freeze_norm:
            lr = 0.
        scale_attr = ParamAttr(
            learning_rate=lr,
            regularizer=L2Decay(norm_decay),
            initializer=nn.initializer.Constant(value=1.0))
        shift_attr = ParamAttr(
            learning_rate=lr,
            regularizer=L2Decay(norm_decay))
        self.normlayer = nn.BatchNorm2D(
            num_features=c_out, momentum=0.997, epsilon=1e-4,
            weight_attr=scale_attr, bias_attr=shift_attr)

        self.relu = nn.ReLU() if use_relu else None

    def forward(self, x):
        """Apply conv -> batch norm, then ReLU if ``use_relu`` was set."""
        out = self.normlayer(self.conv(x))
        return self.relu(out) if self.use_relu else out


class CARAFE(nn.Layer):
    """Unofficial implementation of CARAFE (Content-Aware ReAssembly of FEatures).

    Paper: https://arxiv.org/abs/1905.02188

    A small encoder predicts, for every output pixel, a softmax-normalized
    ``k_up x k_up`` kernel. That kernel reweights the neighbourhood around the
    corresponding source pixel, producing a ``scale``-times upsampled map with
    the same channel count as the input.
    """

    def __init__(self, c, c_mid=64, scale=2, k_up=5, k_enc=3):
        """
        Args:
            c: The channel number of the input and the output.
            c_mid: The channel number after compression.
            scale: The expected upsample scale.
            k_up: The size of the reassembly kernel. Must be odd so that the
                padded unfold in ``forward`` yields exactly one neighbourhood
                per output pixel.
            k_enc: The kernel size of the encoder.
        Raises:
            ValueError: if ``k_up`` is not odd.
        """
        super(CARAFE, self).__init__()
        if k_up % 2 != 1:
            raise ValueError(
                'CARAFE requires an odd k_up, got {}'.format(k_up))
        self.scale = scale
        self.k_up = k_up

        # Channel compressor followed by the kernel encoder. The encoder
        # predicts scale^2 kernels of k_up^2 taps per input pixel.
        self.comp = ConvBNReLU(c, c_mid, kernel_size=1, stride=1,
                               padding=0, dilation=1)
        self.enc = ConvBNReLU(c_mid, (scale * k_up) ** 2, kernel_size=k_enc,
                              stride=1, padding=k_enc // 2, dilation=1,
                              use_relu=False)
        # Rearranges the scale^2 kernels per input pixel onto the upsampled grid.
        self.pix_shf = nn.PixelShuffle(scale)

        self.upsmp = nn.Upsample(scale_factor=scale, mode='nearest')

    def forward(self, X):
        """Upsample ``X`` from (b, c, h, w) to (b, c, h*scale, w*scale)."""
        b, c, h, w = X.shape
        h_, w_ = h * self.scale, w * self.scale

        # Kernel prediction: one k_up x k_up kernel per output pixel.
        W = self.comp(X)          # (b, c_mid, h, w)
        W = self.enc(W)           # (b, (scale*k_up)^2, h, w)
        W = self.pix_shf(W)       # (b, k_up^2, h_, w_)
        W = F.softmax(W, axis=1)  # normalize each kernel over its k_up^2 taps

        # Gather the k_up x k_up source neighbourhood of every output pixel.
        # After nearest upsampling each source pixel is repeated `scale` times,
        # so a dilation of `scale` steps one source pixel at a time. For odd
        # k_up the padding keeps the number of windows equal to h_ * w_.
        X = self.upsmp(X)         # (b, c, h_, w_)
        X = F.unfold(X, kernel_sizes=self.k_up,
                     paddings=self.k_up // 2 * self.scale,
                     dilations=self.scale)          # (b, c*k_up^2, h_*w_)
        X = paddle.reshape(X, (b, c, -1, h_, w_))  # (b, c, k_up^2, h_, w_)

        # Content-aware reassembly: a per-pixel weighted sum over the k_up^2
        # neighbours, shared across channels. (A matmul here would multiply
        # the h_ x w_ planes as matrices, which is not the reassembly.)
        W = paddle.unsqueeze(W, axis=1)    # (b, 1, k_up^2, h_, w_)
        return paddle.sum(W * X, axis=2)   # (b, c, h_, w_)



if __name__ == '__main__':
    # Smoke test: run CARAFE on a random 16-channel 24x24 map and print the
    # output shape. Per the CARAFE docstring the channel count is preserved
    # and H/W are multiplied by `scale` (default 2), i.e. [1, 16, 48, 48].
    carafe = CARAFE(16)
    x = paddle.randn((1, 16, 24, 24))
    out = carafe(x)
    print(out.shape)